{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SDE Drag pipeline\n",
    "\n",
    "This pipeline provides drag-and-drop image editing using stochastic differential equations. It enables image editing by inputting prompt, image, mask_image, source_points, and target_points.See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information. This script was contributed by [Fengqi Zhu](https://github.com/MarkRich) and [NieShen](https://github.com/NieShenRuc).The notebook contributed by [Parag Ekbote](https://github.com/ParagEkbote)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
      "Requirement already satisfied: diffusers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (0.31.0)\n",
      "Requirement already satisfied: torch in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (2.0.1+cu117)\n",
      "Requirement already satisfied: pillow in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (9.0.0)\n",
      "Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (2.32.3)\n",
      "Requirement already satisfied: torchvision in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (0.15.2+cu117)\n",
      "Requirement already satisfied: importlib-metadata in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (8.5.0)\n",
      "Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (3.16.1)\n",
      "Requirement already satisfied: huggingface-hub>=0.23.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.26.3)\n",
      "Requirement already satisfied: numpy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (1.26.4)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2024.11.6)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.4.5)\n",
      "Requirement already satisfied: typing-extensions in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (4.12.2)\n",
      "Requirement already satisfied: sympy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (1.13.3)\n",
      "Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.4.2)\n",
      "Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.1.4)\n",
      "Requirement already satisfied: triton==2.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2.0.0)\n",
      "Requirement already satisfied: cmake in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from triton==2.0.0->torch) (3.25.0)\n",
      "Requirement already satisfied: lit in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from triton==2.0.0->torch) (15.0.7)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests) (2024.8.30)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (2024.9.0)\n",
      "Requirement already satisfied: packaging>=20.9 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (24.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (6.0.2)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (4.66.6)\n",
      "Requirement already satisfied: zipp>=3.20 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from importlib-metadata->diffusers) (3.21.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch) (3.0.2)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "pip install diffusers torch pillow requests torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5a56f794d024c3cbd2f7ca8a647607d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2134f5d940ed4b9cab95cdee644c9de8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "SDE Drag:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output image saved as './output.png'\n",
      "Output type: <class 'numpy.ndarray'>\n",
      "Output shape: (512, 512, 3)\n",
      "Output dtype: uint8\n",
      "Output min/max values: 0, 255\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from diffusers import DDIMScheduler, DiffusionPipeline\n",
    "from PIL import Image\n",
    "import requests\n",
    "from io import BytesIO\n",
    "import numpy as np\n",
    "\n",
    "# Load the pipeline\n",
    "model_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n",
    "scheduler = DDIMScheduler.from_pretrained(model_path, subfolder=\"scheduler\")\n",
    "pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline=\"sde_drag\")\n",
    "\n",
    "# Ensure the model is moved to the GPU\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "pipe.to(device)\n",
    "\n",
    "# Function to load image from URL\n",
    "def load_image_from_url(url):\n",
    "    response = requests.get(url)\n",
    "    return Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
    "\n",
    "# Function to prepare mask\n",
    "def prepare_mask(mask_image):\n",
    "    # Convert to grayscale\n",
    "    mask = mask_image.convert(\"L\")\n",
    "    return mask\n",
    "\n",
    "# Function to convert numpy array to PIL Image\n",
    "def array_to_pil(array):\n",
    "    # Ensure the array is in uint8 format\n",
    "    if array.dtype != np.uint8:\n",
    "        if array.max() <= 1.0:\n",
    "            array = (array * 255).astype(np.uint8)\n",
    "        else:\n",
    "            array = array.astype(np.uint8)\n",
    "    \n",
    "    # Handle different array shapes\n",
    "    if len(array.shape) == 3:\n",
    "        if array.shape[0] == 3:  # If channels first\n",
    "            array = array.transpose(1, 2, 0)\n",
    "        return Image.fromarray(array)\n",
    "    elif len(array.shape) == 4:  # If batch dimension\n",
    "        array = array[0]\n",
    "        if array.shape[0] == 3:  # If channels first\n",
    "            array = array.transpose(1, 2, 0)\n",
    "        return Image.fromarray(array)\n",
    "    else:\n",
    "        raise ValueError(f\"Unexpected array shape: {array.shape}\")\n",
    "\n",
    "# Image and mask URLs\n",
    "image_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png'\n",
    "mask_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png'\n",
    "\n",
    "# Load the images\n",
    "image = load_image_from_url(image_url)\n",
    "mask_image = load_image_from_url(mask_url)\n",
    "\n",
    "# Resize images to a size that's compatible with the model's latent space\n",
    "image = image.resize((512, 512))\n",
    "mask_image = mask_image.resize((512, 512))\n",
    "\n",
    "# Prepare the mask (keep as PIL Image)\n",
    "mask = prepare_mask(mask_image)\n",
    "\n",
    "# Provide the prompt and points for drag editing\n",
    "prompt = \"A cute dog\"\n",
    "source_points = [[32, 32]]  # Adjusted for 512x512 image\n",
    "target_points = [[64, 64]]  # Adjusted for 512x512 image\n",
    "\n",
    "# Generate the output image\n",
    "output_array = pipe(\n",
    "    prompt=prompt,\n",
    "    image=image,\n",
    "    mask_image=mask,\n",
    "    source_points=source_points,\n",
    "    target_points=target_points\n",
    ")\n",
    "\n",
    "# Convert output array to PIL Image and save\n",
    "output_image = array_to_pil(output_array)\n",
    "output_image.save(\"./output.png\")\n",
    "print(\"Output image saved as './output.png'\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
